"""
Regularized evolution as described in:
Real, E., Aggarwal, A., Huang, Y., and Le, Q. V.
Regularized Evolution for Image Classifier Architecture Search.
In Proceedings of the Conference on Artificial Intelligence (AAAI’19)

The code is based on the original regularized evolution open-source implementation:
https://colab.research.google.com/github/google-research/google-research/blob/master/evolution/regularized_evolution_algorithm/regularized_evolution.ipynb

NOTE: This script has certain deviations from the original code owing to the search space of the benchmarks used:
1) The fitness function is not accuracy but error and hence the negative error is being maximized.
2) The architecture is a ConfigSpace object that defines the model architecture parameters.

Adaptations were made to make it compatible with the search spaces.
"""

import argparse
import collections
import copy
import os
import pickle
import random

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from nasbench import api

from nasbench_analysis.search_spaces.search_space_1 import SearchSpace1
from nasbench_analysis.search_spaces.search_space_2 import SearchSpace2
from nasbench_analysis.search_spaces.search_space_3 import SearchSpace3
from nasbench_analysis.utils import upscale_to_nasbench_format, INPUT, OUTPUT, CONV1X1, CONV3X3, \
    MAXPOOL3X3

# Global plot style for the regret figures produced at the bottom of this script.
sns.set_style('whitegrid')

# Lightweight container for a candidate architecture: the cell's adjacency
# matrix plus the list of operations on its intermediate nodes.
Architecture = collections.namedtuple('Architecture', ['adjacency_matrix', 'node_list'])


class Model(object):
    """An individual in the evolutionary search.

    In the context of evolutionary algorithms, a model is often referred to as
    an "individual". The original open-source implementation simulated a toy
    bit-string problem; this adaptation stores a real architecture instead.

    Attributes:
      arch: the candidate architecture, an `Architecture` namedtuple holding
          the cell's adjacency matrix and its list of node operations.
      accuracy: the validation accuracy returned by the NAS-Bench-101 query.
          This is the fitness maximized by the algorithm (note the module
          docstring: regret/error is derived from it for reporting only).
      training_time: the training time reported by the benchmark for this
          architecture.
    """

    def __init__(self):
        self.arch = None
        self.accuracy = None
        self.training_time = None

    def __str__(self):
        """Return a readable representation of this model's architecture.

        Bug fix: the original toy code formatted `arch` as a bit-string via
        '{0:b}'.format(self.arch), which raises once `arch` holds an
        `Architecture` namedtuple rather than an int (as it does here, see
        `random_architecture`/`mutate_arch`). Fall back to plain `str`.
        """
        return str(self.arch)


def train_and_eval(config):
    """Query NAS-Bench-101 for the fitness of an architecture.

    Args:
      config: an `Architecture` namedtuple with fields `adjacency_matrix`
          (numpy array) and `node_list` (ops of the intermediate nodes).

    Returns:
      Tuple of (validation_accuracy, training_time) as reported by the
      benchmark for the queried model spec.

    Uses the module-level globals `search_space` and `nasbench` set up in the
    script body below.
    """
    adjacency_matrix, node_list = config.adjacency_matrix, config.node_list
    if isinstance(search_space, (SearchSpace1, SearchSpace2)):
        # These smaller spaces omit some nodes; pad matrix and op list with
        # entries for the unused nodes so the spec matches the full
        # NAS-Bench-101 cell format.
        adjacency_matrix = upscale_to_nasbench_format(adjacency_matrix)
        node_list = [INPUT, *node_list, CONV1X1, OUTPUT]
    else:
        node_list = [INPUT, *node_list, OUTPUT]
    # Bug fix: `np.int` was deprecated in NumPy 1.20 and removed in 1.24;
    # the builtin `int` is the documented replacement.
    adjacency_list = adjacency_matrix.astype(int).tolist()
    model_spec = api.ModelSpec(matrix=adjacency_list, ops=node_list)
    nasbench_data = nasbench.query(model_spec)
    return nasbench_data['validation_accuracy'], nasbench_data['training_time']


def random_architecture():
    """Sample a random architecture (with loose ends) from the global search space."""
    matrix, ops = search_space.sample_with_loose_ends()
    return Architecture(adjacency_matrix=matrix, node_list=ops)


def mutate_arch(parent_arch):
    """Return a mutated copy of `parent_arch` (an `Architecture` namedtuple).

    One of three mutations is chosen uniformly at random:
      * identity: return an unchanged (deep) copy of the parent,
      * hidden_state_mutation: rewire one incoming edge of an intermediate node,
      * op_mutation: change the operation of one node.
    The parent architecture itself is never modified.
    """
    # Choose one of the three mutations
    mutation = np.random.choice(['identity', 'hidden_state_mutation', 'op_mutation'])

    # Deep-copy so in-place edits below never touch the parent individual.
    adjacency_matrix, node_list = copy.deepcopy(parent_arch.adjacency_matrix), copy.deepcopy(parent_arch.node_list)
    if mutation == 'identity':
        return Architecture(adjacency_matrix=adjacency_matrix, node_list=node_list)
    elif mutation == 'hidden_state_mutation':
        # Pick one of the intermediate nodes in the graph (neither input nor output)
        if type(search_space) == SearchSpace1:
            # NOTE(review): original comment was garbled ("Node 1 has now choice
            # for node 2 as it always has 2 parents"). Presumably in SearchSpace1
            # the nodes below index 3 have no free parent choice (they always
            # have 2 fixed parents), so sampling starts at 3 — confirm.
            low = 3
        else:
            low = 2
        random_node = np.random.randint(low=low, high=adjacency_matrix.shape[-1])
        # Pick one of the current parents of that node (a nonzero entry in its
        # adjacency-matrix column, restricted to earlier nodes).
        parent_of_node_to_modify = np.random.choice(adjacency_matrix[:random_node, random_node].nonzero()[0])
        # Select a new parent for this node among the earlier nodes not currently
        # connected (needs to be different from previous parent).
        new_parent_of_node = np.random.choice(np.argwhere(adjacency_matrix[:random_node, random_node] == 0).flatten())
        # Remove old parent from child
        adjacency_matrix[parent_of_node_to_modify, random_node] = 0
        # Add new parent to child
        adjacency_matrix[new_parent_of_node, random_node] = 1
        # Create new child config
        return Architecture(adjacency_matrix=adjacency_matrix, node_list=node_list)
    else:  # op_mutation
        OPS = [CONV3X3, CONV1X1, MAXPOOL3X3]
        op_idx_to_change = np.random.randint(len(node_list))
        # Remove current op on selected idx (because we want a new op)
        OPS.remove(node_list[op_idx_to_change])
        # Select one of the remaining ops
        new_op = np.random.choice(OPS)
        node_list[op_idx_to_change] = new_op
        return Architecture(adjacency_matrix=adjacency_matrix, node_list=node_list)


def regularized_evolution(cycles, population_size, sample_size):
    """Regularized (aging) evolution, "Algorithm 1" of Real et al. (AAAI'19).

    Args:
      cycles: total number of models to evaluate (including the initial
          random population).
      population_size: number of individuals kept alive at any time.
      sample_size: tournament size for parent selection.

    Returns:
      The full evaluation history as a list of `Model` instances.
    """
    population = collections.deque()
    history = []  # Every model ever evaluated; used only for reporting.

    def _evaluate(arch):
        # Wrap an architecture in a Model and query the benchmark for fitness.
        individual = Model()
        individual.arch = arch
        individual.accuracy, individual.training_time = train_and_eval(individual.arch)
        return individual

    # Seed the population with random architectures.
    for _ in range(population_size):
        individual = _evaluate(random_architecture())
        population.append(individual)
        history.append(individual)

    # Each evolution cycle adds one child and retires the oldest individual.
    while len(history) < cycles:
        # Tournament selection: draw sample_size individuals with replacement.
        # (Inefficient list() copy kept for clarity, as in the original code;
        # benchmark queries dominate the runtime anyway.)
        tournament = [random.choice(list(population)) for _ in range(sample_size)]
        parent = max(tournament, key=lambda ind: ind.accuracy)

        # Mutate the tournament winner to produce the child.
        child = _evaluate(mutate_arch(parent.arch))
        population.append(child)
        history.append(child)

        # Aging: always discard the oldest member, regardless of fitness.
        population.popleft()

    return history


def random_search(cycles):
    """Baseline: evaluate `cycles` independently sampled random architectures."""
    evaluated = []
    for _ in range(cycles):
        candidate = Model()
        candidate.arch = random_architecture()
        candidate.accuracy, candidate.training_time = train_and_eval(candidate.arch)
        evaluated.append(candidate)
    return evaluated


# ---- Command-line interface and global benchmark setup ----------------------
# Runs at import time; the globals defined here (`args`, `nasbench`,
# `search_space`) are consumed by train_and_eval / random_architecture /
# mutate_arch above.
parser = argparse.ArgumentParser()
parser.add_argument('--run_id', default=0, type=int, nargs='?', help='unique number to identify this run')
parser.add_argument('--algorithm', default='RE', choices=['RE', 'RS'])
parser.add_argument('--search_space', default="1", type=str, nargs='?', help='specifies the benchmark')
parser.add_argument('--n_iters', default=280, type=int, nargs='?', help='number of iterations for optimization method')
parser.add_argument('--output_path', default="./", type=str, nargs='?',
                    help='specifies the path where the results will be saved')
parser.add_argument('--data_dir', default="nasbench_analysis/nasbench_data/108_e/nasbench_full.tfrecord", type=str,
                    nargs='?', help='specifies the path to the nasbench data')
parser.add_argument('--pop_size', default=100, type=int, nargs='?', help='population size')
parser.add_argument('--sample_size', default=10, type=int, nargs='?', help='sample_size')
# NOTE(review): --seed is parsed but never read; the loop at the bottom of the
# script sweeps seeds 0-5 instead.
parser.add_argument('--seed', default=0, type=int, help='random seed')

args = parser.parse_args()
# Load the NAS-Bench-101 tfrecord once (slow); queried by train_and_eval().
nasbench = api.NASBench(args.data_dir)
if args.search_space == "1":
    search_space = SearchSpace1()
elif args.search_space == "2":
    search_space = SearchSpace2()
elif args.search_space == "3":
    search_space = SearchSpace3()
else:
    raise ValueError('Unknown search space')

for seed in range(6):
    # Bug fix: seed BOTH RNGs. numpy drives architecture sampling/mutation,
    # while the stdlib `random` module drives tournament selection inside
    # regularized_evolution; the original only seeded numpy, so runs were
    # not reproducible despite the "Set random_seed" intent.
    np.random.seed(seed)
    random.seed(seed)
    output_path = os.path.join(args.output_path, "regularized_evolution")
    os.makedirs(output_path, exist_ok=True)

    if args.algorithm == 'RE':
        history = regularized_evolution(
            cycles=args.n_iters, population_size=args.pop_size, sample_size=args.sample_size)
    else:
        history = random_search(cycles=args.n_iters)

    # Plot per-iteration validation regret: error of the evaluated model minus
    # the best achievable validation error of the search space.
    plt.figure()
    for idx, arch in enumerate(history):
        plt.scatter(idx, 1 - arch.accuracy - search_space.valid_min_error)

    plt.grid(True, which="both", ls="-", alpha=0.5)
    ax = plt.gca()
    ax.set_yscale('log')
    plt.xlabel('Iteration')
    plt.ylabel('Validation regret')
    plt.title('Search Space {}'.format(args.search_space))
    plt.savefig('test_re_ss_{}_seed_{}.pdf'.format(args.search_space, seed))
    plt.close()

    # Persist the full evaluation history. `with` guarantees the handle is
    # closed even if pickling raises (the original left it open on failure).
    with open(os.path.join(output_path,
                           'algo_{}_{}_ssp_{}_seed_{}.obj'.format(args.algorithm, args.run_id,
                                                                  args.search_space, seed)), 'wb') as fh:
        pickle.dump(history, fh)
